import torchvision.transforms as T
import torchvision.datasets as datasets
import torch
import os


def choose_dataset(dataset_name: str, batch_size: int, num_workers: int, datapath: str):
    """
    Select a dataset by name and return its data loaders.

    Args:
        dataset_name: one of 'cifar10', 'fashion_mnist' or 'mnist'.
        batch_size: batch size forwarded to the selected loader function.
        num_workers: number of DataLoader worker processes, forwarded likewise.
        datapath: accepted for interface compatibility but currently unused;
            the loader functions are called with their default data directory.

    Returns:
        The (train_loader, val_loader, test_loader) tuple produced by the
        matching ``load_*`` function.

    Raises:
        SystemExit: with exit code 1 if ``dataset_name`` is not recognised.
    """
    import sys

    if dataset_name == 'cifar10':
        return load_cifar10(batch_size, num_workers)
    if dataset_name == 'fashion_mnist':
        return load_fashion_mnist(batch_size, num_workers)
    if dataset_name == 'mnist':
        return load_mnist(batch_size, num_workers)

    # Use sys.exit: the bare ``exit`` builtin is injected by the ``site`` module
    # for the interactive interpreter and is not guaranteed to exist in programs.
    print("dataset not available. Exiting", file=sys.stderr)
    sys.exit(1)



def load_cifar10(batch_size: int, num_workers: int, root: str = './data'):
    """
    Build CIFAR-10 data loaders.

    The CIFAR-10 training split is used for training and the ``train=False``
    split is used for validation. No separate test loader is built, so the
    third element of the returned tuple is always ``None``.

    Args:
        batch_size: number of samples per batch for both loaders.
        num_workers: number of DataLoader worker processes.
        root: directory the dataset is read from / downloaded into.

    Returns:
        (train_loader, val_loader, None)
    """
    transform = T.Compose([T.ToTensor()])
    train_set = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    val_set = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    # Validation metrics do not depend on sample order; keep the order fixed
    # so per-batch validation output is reproducible.
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False,
                                             num_workers=num_workers)
    return train_loader, val_loader, None


def load_fashion_mnist(batch_size: int, num_workers: int, root: str = './data'):
    """
    Build Fashion-MNIST data loaders.

    The Fashion-MNIST training split is used for training and the
    ``train=False`` split is used for validation. No separate test loader is
    built, so the third element of the returned tuple is always ``None``.

    Args:
        batch_size: number of samples per batch for both loaders.
        num_workers: number of DataLoader worker processes.
        root: directory the dataset is read from / downloaded into.

    Returns:
        (train_loader, val_loader, None)
    """
    transform = T.Compose([T.ToTensor()])
    train_set = datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    val_set = datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    # Validation metrics do not depend on sample order; keep the order fixed
    # so per-batch validation output is reproducible.
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False,
                                             num_workers=num_workers)
    return train_loader, val_loader, None



def load_mnist(batch_size: int, num_workers: int, root: str = './data'):
    """
    Build MNIST data loaders.

    The MNIST training split is used for training and the ``train=False``
    split is used for validation. No separate test loader is built, so the
    third element of the returned tuple is always ``None``.

    Args:
        batch_size: number of samples per batch for both loaders.
        num_workers: number of DataLoader worker processes.
        root: directory the dataset is read from / downloaded into.

    Returns:
        (train_loader, val_loader, None)
    """
    transform = T.Compose([T.ToTensor()])
    train_set = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    val_set = datasets.MNIST(root=root, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers)
    # Validation metrics do not depend on sample order; keep the order fixed
    # so per-batch validation output is reproducible.
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False,
                                             num_workers=num_workers)
    return train_loader, val_loader, None
